{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c1278eb0-ad96-4d9e-8db7-eb3644818b9c",
   "metadata": {},
   "source": [
    "# **Jupyter Notebook: DistilBERT CPU FLOPs calculation**\n",
    "\n",
    "## **Block 1: Load Model**\n",
    "This block is responsible for:\n",
    "- **Loading the `DistilBERT` pre-trained model**\n",
    "- **Loading the `DistilBERT` Tokenizer**\n",
    "- **Printing basic model information (number of parameters, layers, hidden size, etc.)**\n",
    "- **Ensuring the model runs on CPU**\n",
    "- **Reading model configuration information**\n",
    "- **Printing the weight dtype of selected layers (embedding, FFN, self-attention)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfed0f8f-556b-4e68-9526-9486b1af74f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using PyTorch version: 2.5.1\n",
      "\n",
      "===== DistilBERT Model Info =====\n",
      "Number of Parameters: 66,362,880\n",
      "Number of Transformer Layers: 6\n",
      "Hidden Size: 768\n",
      "FFN Hidden Size (hidden_dim): 3072\n",
      "Number of Attention Heads: 12\n",
      "Embedding Layer dtype: torch.float32\n",
      "Transformer Block 0 FFN dtype: torch.float32\n",
      "Transformer Block 1 Self-Attention dtype: torch.float32\n"
     ]
    }
   ],
   "source": [
    "### Block 1: Load Model ###\n",
    "import torch\n",
    "import time\n",
    "from transformers import DistilBertTokenizer, DistilBertModel\n",
    "\n",
    "# Print PyTorch version \n",
    "print(f\"Using PyTorch version: {torch.__version__}\")\n",
    "\n",
    "# Load the pre-trained model and tokenizer\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "model = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
    "\n",
    "# Ensure the model runs on CPU\n",
    "device = torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# Read model configuration information\n",
    "config_dict = model.config.to_dict()\n",
    "\n",
    "print(\"\\n===== DistilBERT Model Info =====\")\n",
    "print(f\"Number of Parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
    "print(f\"Number of Transformer Layers: {config_dict['n_layers']}\")\n",
    "print(f\"Hidden Size: {config_dict['dim']}\")\n",
    "print(f\"FFN Hidden Size (hidden_dim): {config_dict['hidden_dim']}\")\n",
    "print(f\"Number of Attention Heads: {config_dict['n_heads']}\")\n",
    "\n",
    "print(f\"Embedding Layer dtype: {model.embeddings.word_embeddings.weight.dtype}\")  # Print data type of embedding layer\n",
    "print(f\"Transformer Block 0 FFN dtype: {model.transformer.layer[0].ffn.lin1.weight.dtype}\")  # Print data type of FFN layer in the first transformer block\n",
    "print(f\"Transformer Block 1 Self-Attention dtype: {model.transformer.layer[1].attention.q_lin.weight.dtype}\")  # Print data type of self-attention layer in the second transformer block"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44bd47c2-3af7-4557-9094-78bb14bcc30d",
   "metadata": {},
   "source": [
    "## **Block 2: Tokenization**\n",
    "\n",
    "- **Converting input text into tokens**\n",
    "- **Printing the tokenized results**\n",
    "- **Checking `input_ids` and `attention_mask`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2f784745-a509-4725-b252-23c383234fbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== Tokenization Results =====\n",
      "Input Text: The scientist carefully examined the ancient manuscript, deciphering cryptic symbols that revealed a forgotten civilization’s history, including detailed maps, complex astronomical calculations, and mysterious prophecies about future events, leaving researchers astonished and eager to uncover deeper secrets hidden within the fragile pages of the centuries-old document stored inside a\n",
      "Tokenized Tokens: ['the', 'scientist', 'carefully', 'examined', 'the', 'ancient', 'manuscript', ',', 'dec', '##ip', '##hering', 'cryptic', 'symbols', 'that', 'revealed', 'a', 'forgotten', 'civilization', '’', 's', 'history', ',', 'including', 'detailed', 'maps', ',', 'complex', 'astronomical', 'calculations', ',', 'and', 'mysterious', 'prop', '##he', '##cies', 'about', 'future', 'events', ',', 'leaving', 'researchers', 'astonished', 'and', 'eager', 'to', 'uncover', 'deeper', 'secrets', 'hidden', 'within', 'the', 'fragile', 'pages', 'of', 'the', 'centuries', '-', 'old', 'document', 'stored', 'inside', 'a']\n",
      "Tokens Length: 64\n",
      "Token IDs: tensor([[  101,  1996,  7155,  5362,  8920,  1996,  3418,  8356,  1010, 11703,\n",
      "         11514, 22658, 26483,  9255,  2008,  3936,  1037,  6404, 10585,  1521,\n",
      "          1055,  2381,  1010,  2164,  6851,  7341,  1010,  3375, 13674, 16268,\n",
      "          1010,  1998,  8075, 17678,  5369,  9243,  2055,  2925,  2824,  1010,\n",
      "          2975,  6950, 22741,  1998,  9461,  2000, 26944,  6748,  7800,  5023,\n",
      "          2306,  1996, 13072,  5530,  1997,  1996,  4693,  1011,  2214,  6254,\n",
      "          8250,  2503,  1037,   102]])\n",
      "Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n"
     ]
    }
   ],
   "source": [
    "### Block 2: Tokenization ###\n",
    "# Define input text\n",
    "text = \"The scientist carefully examined the ancient manuscript, deciphering cryptic symbols that revealed a forgotten civilization’s history, including detailed maps, complex astronomical calculations, and mysterious prophecies about future events, leaving researchers astonished and eager to uncover deeper secrets hidden within the fragile pages of the centuries-old document stored inside a\"\n",
    "\n",
    "# Perform tokenization\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# Print tokenized results\n",
    "print(\"\\n===== Tokenization Results =====\")\n",
    "print(f\"Input Text: {text}\")\n",
    "print(f\"Tokenized Tokens: {tokenizer.tokenize(text)}\")\n",
    "print(f\"Tokens Length: {inputs.input_ids.size()[1]}\")\n",
    "print(f\"Token IDs: {inputs['input_ids']}\")\n",
    "print(f\"Attention Mask: {inputs['attention_mask']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c340dcbf-dcda-4581-b53c-f73d83f9a36f",
   "metadata": {},
   "source": [
    "## **Block 3: Embedding Calculation**\n",
    "\n",
    "- **Convert token IDs into 768-dimensional embedding vectors**\n",
    "- **Check the shape of the `Embedding` output**\n",
    "- **Compute `Embedding` FLOPs (computational cost)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b204082c-99f0-4b2c-a7bd-9545a958e244",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== Token Embedding Results =====\n",
      "Token IDs Shape: torch.Size([1, 64])\n",
      "Embedding Output Shape: torch.Size([1, 64, 768])\n",
      "Embedding FLOPs: 49,152 FLOPs\n"
     ]
    }
   ],
   "source": [
    "### Block 3: Compute Token Embedding ###\n",
    "# Get the model's embedding layer\n",
    "embedding_layer = model.embeddings.word_embeddings\n",
    "\n",
    "# Compute embeddings for token IDs\n",
    "token_embeddings = embedding_layer(inputs[\"input_ids\"])\n",
    "\n",
    "# Print results\n",
    "print(\"\\n===== Token Embedding Results =====\")\n",
    "print(f\"Token IDs Shape: {inputs['input_ids'].shape}\")\n",
    "print(f\"Embedding Output Shape: {token_embeddings.shape}\")  # (batch_size, seq_len, hidden_dim)\n",
    "\n",
    "# Compute FLOPs\n",
    "seq_length = token_embeddings.shape[1]\n",
    "hidden_dim = token_embeddings.shape[2]\n",
    "embedding_flops = seq_length * hidden_dim\n",
    "\n",
    "print(f\"Embedding FLOPs: {embedding_flops:,} FLOPs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88fdd47d-2e9f-470a-aa0c-9c23f3210b4d",
   "metadata": {},
   "source": [
    "## **Block 4: Self-Attention**\n",
    "\n",
    "- **Calculate Query (Q), Key (K), Value (V) Projection**\n",
    "- **Calculate multi-head `QK^T` (Attention Scores)**\n",
    "- **Calculate multi-head `Softmax(QK^T) V` (Weighted Attention)**\n",
    "- **Calculate Output (O) Projection**\n",
    "- **Sum Self-Attention FLOPs**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "814d04eb-b30b-4061-bcaa-9e723611c509",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== Multi-Head Self-Attention Calculation Results =====\n",
      "Multi-Head Self-Attention FLOPs: 314.43 MFLOPs\n"
     ]
    }
   ],
   "source": [
    "### Block 4: Multi-Head Self-Attention Calculation ###\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Extract the first layer of DistilBERT Transformer\n",
    "transformer_layer = model.transformer.layer[0]\n",
    "\n",
    "# Get parameters from the model configuration\n",
    "num_heads = model.config.n_heads  # 12 Attention heads\n",
    "hidden_dim = model.config.dim  # 768\n",
    "d_k = hidden_dim // num_heads  # Calculate the dimension of each Attention head\n",
    "\n",
    "# Get Self-Attention weights\n",
    "W_Q = transformer_layer.attention.q_lin.weight\n",
    "W_K = transformer_layer.attention.k_lin.weight\n",
    "W_V = transformer_layer.attention.v_lin.weight\n",
    "W_O = transformer_layer.attention.out_lin.weight  # Final projection of MHA\n",
    "\n",
    "# Calculate MatMul FLOPs (Multiply)\n",
    "flops_qkv_mul= 3 * seq_length * hidden_dim * hidden_dim  # Q, K, V projections\n",
    "flops_qk_t_mul = num_heads * seq_length * seq_length * d_k  # QK^T\n",
    "flops_attn_v_mul = num_heads * seq_length * seq_length * d_k  # Softmax(QK^T) * V\n",
    "flops_output_proj_mul = seq_length * hidden_dim * hidden_dim  # W_O projection\n",
    "\n",
    "# Calculate MatMul FLOPs (Add)\n",
    "flops_qkv_add = 3 * seq_length * hidden_dim * (hidden_dim - 1) # Q, K, V projections\n",
    "flops_qk_t_add = num_heads * seq_length * seq_length * (d_k - 1) # QK^T\n",
    "flops_attn_v_add = num_heads * seq_length * seq_length * (d_k - 1)  # Softmax(QK^T) * V\n",
    "flops_output_proj_add = seq_length * hidden_dim * (hidden_dim - 1)  # W_O projection\n",
    "\n",
    "# Calculate total MatMul FLOPs for Attention\n",
    "flops_qkv = flops_qkv_mul + flops_qkv_add\n",
    "flops_qk_t = flops_qk_t_mul + flops_qk_t_add\n",
    "flops_attn_v = flops_attn_v_mul + flops_attn_v_add\n",
    "flops_output_proj = flops_output_proj_mul + flops_output_proj_add\n",
    "\n",
    "# Calculate Softmax FLOPs\n",
    "softmax_exp = num_heads * seq_length * seq_length  # Exponential calculation\n",
    "softmax_exp_add = num_heads * seq_length * seq_length  # Add the Exponential results\n",
    "softmax_div = num_heads * seq_length * seq_length  # Division by sum of Exponential\n",
    "flops_softmax = softmax_exp + softmax_exp_add + softmax_div # Total Softmax FLOPs\n",
    "\n",
    "# Total FLOPs calculation\n",
    "total_mha_flops = flops_qkv + flops_qk_t + flops_attn_v + flops_output_proj + flops_softmax\n",
    "\n",
    "print(\"\\n===== Multi-Head Self-Attention Calculation Results =====\")\n",
    "print(f\"Multi-Head Self-Attention FLOPs: {total_mha_flops / 1e6:.2f} MFLOPs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "355df891-cd61-41c8-ba67-197470b83e24",
   "metadata": {},
   "source": [
    "## **Block 5: Feed Forward Network (FFN) Calculation**\n",
    "\n",
    "- **Compute `FFN(x) = GELU(xW1 + b1) W2 + b2`**\n",
    "- **Check the shape of the `FFN` output**\n",
    "- **Count `FFN` FLOPs (computational cost)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e98f5926-f56b-48bb-b7a9-b64429f405df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== Feed Forward Network (FFN) Calculation Results =====\n",
      "Feed Forward Network (FFN) FLOPs: 604.52 MFLOPs\n"
     ]
    }
   ],
   "source": [
    "### Block 5: Calculate Feed Forward Network (FFN) ###\n",
    "# Get FFN weights\n",
    "W1 = transformer_layer.ffn.lin1.weight\n",
    "W2 = transformer_layer.ffn.lin2.weight\n",
    "\n",
    "intermediate_size = model.config.hidden_dim  # Internal FFN dimension (3072)\n",
    "\n",
    "# Calculate FLOPs\n",
    "# First Linear Layer (768 -> 3072)\n",
    "flops_ffn1_mul = seq_length * hidden_dim * intermediate_size  # MatMul\n",
    "flops_ffn1_add = seq_length * intermediate_size * (hidden_dim - 1)  # Add\n",
    "flops_ffn1 = flops_ffn1_mul + flops_ffn1_add\n",
    "\n",
    "# GELU Activation\n",
    "flops_gelu = seq_length * intermediate_size * 4  # GELU computation (approximately 4 FLOPs per element)\n",
    "\n",
    "# Second Linear Layer (3072 -> 768)\n",
    "flops_ffn2_mul = seq_length * intermediate_size * hidden_dim  # MatMul\n",
    "flops_ffn2_add = seq_length * hidden_dim * (intermediate_size - 1)  # Add\n",
    "flops_ffn2 = flops_ffn2_mul + flops_ffn2_add\n",
    "\n",
    "# Total FLOPs calculation\n",
    "total_ffn_flops = flops_ffn1 + flops_gelu + flops_ffn2\n",
    "\n",
    "print(\"\\n===== Feed Forward Network (FFN) Calculation Results =====\")\n",
    "print(f\"Feed Forward Network (FFN) FLOPs: {total_ffn_flops / 1e6:.2f} MFLOPs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "738ac22b-831c-4cd8-9ccd-019594f1536d",
   "metadata": {},
   "source": [
    "## **Block 6: Layer Normalization & Residual Connection**\n",
    "\n",
    "- **Compute `LayerNorm(x + residual)`**\n",
    "- **Check the shape of the `LayerNorm` output**\n",
    "- **Count `LayerNorm` FLOPs (computational cost)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "08a83d5d-c130-4f72-a848-f2aa1dc7c126",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== LayerNorm & Residual Calculation Results =====\n",
      "LayerNorm & Residual FLOPs: 0.89 MFLOPs\n"
     ]
    }
   ],
   "source": [
    "### Block 6: Compute LayerNorm and Residual Connection FLOPs ###\n",
    "# Compute LayerNorm FLOPs (2 times, 6147 FLOPs per token)\n",
    "flops_layernorm = 2 * seq_length * 6147\n",
    "\n",
    "# Compute Residual Connection FLOPs (2 times, 768 FLOPs per token)\n",
    "flops_residual = 2 * seq_length * hidden_dim\n",
    "\n",
    "# Total FLOPs\n",
    "total_layernorm_residual_flops = flops_layernorm + flops_residual\n",
    "\n",
    "print(\"\\n===== LayerNorm & Residual Calculation Results =====\")\n",
    "print(f\"LayerNorm & Residual FLOPs: {total_layernorm_residual_flops / 1e6:.2f} MFLOPs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46cbf1d5-d76a-4e70-9c22-17cd2700ee56",
   "metadata": {},
   "source": [
    "## **Block 7: Transformer Layer Overall Computation Statistics**\n",
    "This block is responsible for:\n",
    "- **Summarizing the total computation of Self-Attention, FFN, and LayerNorm**\n",
    "- **Calculating the FLOPs of a Transformer layer**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "14368dde-bdaf-4f1d-b007-4cfca9307281",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== Token Embedding Results =====\n",
      "Token IDs Shape: torch.Size([1, 64])\n",
      "Embedding Output Shape: torch.Size([1, 64, 768])\n",
      "Total FLOPs Per-Layer ≈ 919.880 MFLOPs\n",
      "Total FLOPs Model ≈ 5.519 GFLOPs\n"
     ]
    }
   ],
   "source": [
    "### Block 3: Compute Token Embedding ###\n",
    "# Get the model's embedding layer\n",
    "embedding_layer = model.embeddings.word_embeddings\n",
    "\n",
    "# Compute embeddings for token IDs\n",
    "token_embeddings = embedding_layer(inputs[\"input_ids\"])\n",
    "\n",
    "# Print results\n",
    "print(\"\\n===== Token Embedding Results =====\")\n",
    "print(f\"Token IDs Shape: {inputs['input_ids'].shape}\")\n",
    "print(f\"Embedding Output Shape: {token_embeddings.shape}\")  # (batch_size, seq_len, hidden_dim)\n",
    "\n",
    "# Compute FLOPs\n",
    "seq_length = token_embeddings.shape[1]\n",
    "hidden_dim = token_embeddings.shape[2]\n",
    "embedding_flops = seq_length * hidden_dim  # Embedding FLOPs\n",
    "\n",
    "# Compute total FLOPs (including Residual, LayerNorm, MHA, FFN)\n",
    "total_distilbert_flops = (\n",
    "    embedding_flops + total_mha_flops + total_ffn_flops + total_layernorm_residual_flops\n",
    ")\n",
    "\n",
    "print(f\"Total FLOPs Per-Layer ≈ {total_distilbert_flops / 1e6:.3f} MFLOPs\")\n",
    "print(f\"Total FLOPs Model ≈ {total_distilbert_flops / 1e9 * config_dict['n_layers']:.3f} GFLOPs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97877684-f437-416b-ba8a-9d0cd7f9b6c1",
   "metadata": {},
   "source": [
    "## **Block 8: CPU Direct Execution Time**\n",
    "- **Measure CPU execution time**\n",
    "- **Calculate `DistilBERT` FLOPs per second on the CPU**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "99bbe1e6-a891-4c41-bf28-2754a781bcae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== CPU Execution Time Statistics  =====\n",
      "Input Token Shape: torch.Size([1, 64])\n",
      "Output Shape: torch.Size([1, 64, 768])\n",
      "Total Execution Time: 0.034835 sec\n",
      "FLOPs per Second (GFLOPs/s): 26.41 GFLOPs/s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from transformers import DistilBertTokenizer, DistilBertModel\n",
    "\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "model = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
    "\n",
    "device = torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# Run the full DistilBERT inference again and measure execution time\n",
    "start_time = time.time()\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "end_time = time.time()\n",
    "\n",
    "execution_time = end_time - start_time\n",
    "\n",
    "print(f\"\\n===== CPU Execution Time Statistics  =====\")\n",
    "print(f\"Input Token Shape: {inputs['input_ids'].shape}\")\n",
    "print(f\"Output Shape: {outputs.last_hidden_state.shape}\")\n",
    "print(f\"Total Execution Time: {execution_time:.6f} sec\")\n",
    "print(f\"FLOPs per Second (GFLOPs/s): {total_distilbert_flops / execution_time / 1e9:.2f} GFLOPs/s\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
